{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "df8313a0",
   "metadata": {},
   "source": [
    "# Chapter 9: A Line-by-Line Implementation of Attention and Transformer\n",
    "\n",
    "This chapter covers\n",
    "\n",
    "* The functions of encoders and decoders in Transformers\n",
    "* How the attention mechanism uses query, key, and value to assign weights to elements in a sequence \n",
    "* Building and training a Transformer from scratch to translate English to French\n",
    "* Using the trained Transformer to translate an English phrase into French\n",
    "\n",
    "Transformers are advanced deep learning models that excel in handling sequence-to-sequence prediction challenges, outperforming older models like recurrent neural networks (RNNs) and convolutional neural networks (CNNs). Their strength lies in effectively understanding the relationships between elements in input and output sequences over long distances. Unlike RNNs, transformers are capable of parallel training, significantly cutting down training times and enabling the handling of vast datasets. This transformative architecture has been pivotal in the development of large language models (LLMs) like ChatGPT, BERT, and T5, marking a significant milestone in AI progress.\n",
    "\n",
    "Before Transformers were introduced in the 2017 paper Attention Is All You Need, natural language processing (NLP) and similar tasks relied primarily on RNNs, including long short-term memory (LSTM) models. RNNs, however, process a sequence one element at a time. This makes them slow, because training cannot be parallelized across positions, and it makes it hard for them to retain information about earlier parts of a sequence, so they often fail to capture long-range dependencies.\n",
    "\n",
    "The key innovation of the Transformer architecture is its attention mechanism. For each word in a sequence, attention assigns a weight to every word in the sequence (including itself) that reflects how relevant that word is; the parameters that produce these weights are learned from the training data. This lets models like ChatGPT capture the relationships between words and thus handle human language more effectively. Because inputs are not processed sequentially, training can be parallelized, which shortens training time and makes it practical to train on large datasets. This has powered the rise of capable LLMs and the current surge in AI advancements.\n",
    "\n",
    "In this chapter, we'll build a Transformer from the ground up, based on the paper Attention Is All You Need, to translate English into French. We'll explore the inner workings of the self-attention mechanism, including the roles of the query, key, and value vectors and the computation of scaled dot-product attention (SDPA). We'll construct an encoder layer by wrapping a multi-head attention layer and a feed-forward layer with layer normalization and residual connections, and then stack six of these encoder layers to form the encoder. Similarly, we'll build the decoder and learn to generate French translations one token at a time, in an autoregressive manner, from the encoder's output.\n",
    "\n",
    "Finally, we’ll train our model on a dataset of more than 47,000 English-to-French translation pairs. The trained model can translate many common English phrases accurately, as if you were using Google Translate for the task."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97468550",
   "metadata": {},
   "source": [
    "# 1. Introduction to Transformers and Attention\n",
    "## 1.1. What is attention?"
   ]
  },
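  {
   "cell_type": "markdown",
   "id": "3f1a9c01",
   "metadata": {},
   "source": [
    "As a first intuition, attention computes a weighted average. A word's query vector is compared with the key vector of every word in the sequence, softmax turns the similarity scores into weights that are non-negative and sum to 1, and those weights are used to average the value vectors. The NumPy sketch below uses made-up 2-dimensional vectors for a 3-token sequence; the numbers are purely illustrative, whereas in a real Transformer the queries, keys, and values come from learned projections. The scaling and masking used in the full version appear later in the chapter.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(z):\n",
    "    # subtract the max for numerical stability before exponentiating\n",
    "    e = np.exp(z - z.max())\n",
    "    return e / e.sum()\n",
    "\n",
    "# toy query, keys, and values for a 3-token sequence (illustrative numbers)\n",
    "query = np.array([1.0, 0.0])\n",
    "keys = np.array([[1.0, 0.0],\n",
    "                 [0.0, 1.0],\n",
    "                 [0.5, 0.5]])\n",
    "values = np.array([[10.0, 0.0],\n",
    "                   [0.0, 10.0],\n",
    "                   [5.0, 5.0]])\n",
    "\n",
    "scores = keys @ query        # dot-product similarity with each key\n",
    "weights = softmax(scores)    # non-negative weights that sum to 1\n",
    "output = weights @ values    # weighted average of the value vectors\n",
    "print(np.round(weights, 3))\n",
    "print(np.round(output, 3))\n",
    "```\n",
    "\n",
    "The first key points in the same direction as the query, so the first token receives the largest weight and the output is pulled toward its value vector."
   ]
  },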
  {
   "cell_type": "markdown",
   "id": "7da22ef8",
   "metadata": {},
   "source": [
    "## 1.2. The transformer architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee4d965",
   "metadata": {},
   "source": [
    "# 2. Word Embedding and Positional Encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d30d2dc3",
   "metadata": {},
   "source": [
    "## 2.1. Word Tokenization\n",
    "First, go to https://gattonweb.uky.edu/faculty/lium/gai/en2fr.zip to download the zip file that contains the more than 47,000 English-to-French translations that I collected from various sources. Unzip the file and place en2fr.csv in a folder named files/ inside your working directory (the code below reads files/en2fr.csv). We'll load the data and take a look as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aafd4b37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "there are 47173 examples in the training data\n",
      "How are you?\n",
      "Comment êtes-vous?\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df=pd.read_csv(\"files/en2fr.csv\")\n",
    "num_examples=len(df)\n",
    "print(f\"there are {num_examples} examples in the training data\")\n",
    "print(df.iloc[30856][\"en\"])\n",
    "print(df.iloc[30856][\"fr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c88c9df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install transformers sacremoses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a06a33bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i</w>', 'don</w>', \"'t</w>\", 'speak</w>', 'fr', 'ench</w>', '.</w>']\n",
      "['je</w>', 'ne</w>', 'parle</w>', 'pas</w>', 'franc', 'ais</w>', '.</w>']\n",
      "['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
      "['comment</w>', 'et', 'es-vous</w>', '?</w>']\n"
     ]
    }
   ],
   "source": [
    "from transformers import XLMTokenizer\n",
    "\n",
    "tokenizer = XLMTokenizer.from_pretrained(\"xlm-clm-enfr-1024\")\n",
    "\n",
    "tokenized_en=tokenizer.tokenize(\"I don't speak French.\")\n",
    "print(tokenized_en)\n",
    "tokenized_fr=tokenizer.tokenize(\"Je ne parle pas français.\")\n",
    "print(tokenized_fr)\n",
    "print(tokenizer.tokenize(\"How are you?\"))\n",
    "print(tokenizer.tokenize(\"Comment êtes-vous?\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7affb028",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build dictionaries\n",
    "from collections import Counter\n",
    "\n",
    "en=df[\"en\"].tolist()\n",
    "\n",
    "en_tokens=[[\"BOS\"]+tokenizer.tokenize(x)+[\"EOS\"] for x in en]        \n",
    "PAD=0\n",
    "UNK=1\n",
    "# apply to English \n",
    "word_count=Counter()\n",
    "for sentence in en_tokens:\n",
    "    for word in sentence:\n",
    "        word_count[word]+=1\n",
    "frequency=word_count.most_common(50000)        \n",
    "total_en_words=len(frequency)+2\n",
    "en_word_dict={w[0]:idx+2 for idx,w in enumerate(frequency)}\n",
    "en_word_dict[\"PAD\"]=PAD\n",
    "en_word_dict[\"UNK\"]=UNK\n",
    "# another dictionary to map numbers to tokens\n",
    "en_idx_dict={v:k for k,v in en_word_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0d0abaf5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[15, 100, 38, 377, 476, 574, 5]\n"
     ]
    }
   ],
   "source": [
    "enidx=[en_word_dict.get(i,UNK) for i in tokenized_en]   \n",
    "print(enidx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "02572e98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i</w>', 'don</w>', \"'t</w>\", 'speak</w>', 'fr', 'ench</w>', '.</w>']\n",
      "i don't speak french. \n"
     ]
    }
   ],
   "source": [
    "entokens=[en_idx_dict.get(i,\"UNK\") for i in enidx]   \n",
    "print(entokens)\n",
    "en_phrase=\"\".join(entokens)\n",
    "en_phrase=en_phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    en_phrase=en_phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(en_phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c94fa3cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[157, 17, 22, 26]\n",
      "['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
      "how are you? \n"
     ]
    }
   ],
   "source": [
    "# exercise 9.1\n",
    "tokens=['how</w>', 'are</w>', 'you</w>', '?</w>']\n",
    "indexes=[en_word_dict.get(i,UNK) for i in tokens]   \n",
    "print(indexes)\n",
    "tokens=[en_idx_dict.get(i,\"UNK\") for i in indexes]   \n",
    "print(tokens)\n",
    "phrase=\"\".join(tokens)\n",
    "phrase=phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    phrase=phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f2bd162",
   "metadata": {},
   "outputs": [],
   "source": [
    "# do the same for French phrases\n",
    "fr=df[\"fr\"].tolist()       \n",
    "fr_tokens=[[\"BOS\"]+tokenizer.tokenize(x)+[\"EOS\"] for x in fr] \n",
    "word_count=Counter()\n",
    "for sentence in fr_tokens:\n",
    "    for word in sentence:\n",
    "        word_count[word]+=1\n",
    "frequency=word_count.most_common(50000)        \n",
    "total_fr_words=len(frequency)+2\n",
    "fr_word_dict={w[0]:idx+2 for idx,w in enumerate(frequency)}\n",
    "fr_word_dict[\"PAD\"]=PAD\n",
    "fr_word_dict[\"UNK\"]=UNK\n",
    "fr_idx_dict={v:k for k,v in fr_word_dict.items()}"
   ]
  },
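  {
   "cell_type": "markdown",
   "id": "3f1a9c02",
   "metadata": {},
   "source": [
    "The English and French dictionaries above are built with identical steps. One way to avoid repeating them is a small helper function; `build_vocab()` below is a hypothetical name, not part of the book's code or of ch09util.py, and it follows the same logic: count tokens, sort them by frequency, give the most frequent token index 2, and reserve 0 for PAD and 1 for UNK. The two toy sentences are made up.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "PAD, UNK = 0, 1\n",
    "\n",
    "def build_vocab(tokenized_sentences, max_words=50000):\n",
    "    # hypothetical helper mirroring the dictionary-building steps above\n",
    "    word_count = Counter()\n",
    "    for sentence in tokenized_sentences:\n",
    "        word_count.update(sentence)\n",
    "    frequency = word_count.most_common(max_words)\n",
    "    # the most frequent token gets index 2; 0 and 1 are reserved\n",
    "    word_dict = {w: idx + 2 for idx, (w, _) in enumerate(frequency)}\n",
    "    word_dict['PAD'] = PAD\n",
    "    word_dict['UNK'] = UNK\n",
    "    # reverse mapping from index back to token\n",
    "    idx_dict = {v: k for k, v in word_dict.items()}\n",
    "    return word_dict, idx_dict\n",
    "\n",
    "sentences = [['BOS', 'how</w>', 'are</w>', 'you</w>', '?</w>', 'EOS'],\n",
    "             ['BOS', 'you</w>', '?</w>', 'EOS']]\n",
    "word_dict, idx_dict = build_vocab(sentences)\n",
    "print(word_dict)\n",
    "```\n",
    "\n",
    "Calling `build_vocab(en_tokens)` and `build_vocab(fr_tokens)` should reproduce `en_word_dict`/`en_idx_dict` and `fr_word_dict`/`fr_idx_dict`."
   ]
  },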
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e4e843fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[28, 40, 231, 32, 726, 370, 4]\n"
     ]
    }
   ],
   "source": [
    "fridx=[fr_word_dict.get(i,UNK) for i in tokenized_fr]   \n",
    "print(fridx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e8a78dd5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['je</w>', 'ne</w>', 'parle</w>', 'pas</w>', 'franc', 'ais</w>', '.</w>']\n",
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "frtokens=[fr_idx_dict.get(i,\"UNK\") for i in fridx]   \n",
    "print(frtokens)\n",
    "fr_phrase=\"\".join(frtokens)\n",
    "fr_phrase=fr_phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    fr_phrase=fr_phrase.replace(f\" {x}\",f\"{x}\")  \n",
    "print(fr_phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9ec5627f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[452, 61, 742, 30]\n",
      "['comment</w>', 'et', 'es-vous</w>', '?</w>']\n",
      "comment etes-vous? \n"
     ]
    }
   ],
   "source": [
    "# exercise 9.2\n",
    "tokens=['comment</w>', 'et', 'es-vous</w>', '?</w>']\n",
    "indexes=[fr_word_dict.get(i,UNK) for i in tokens]   \n",
    "print(indexes)\n",
    "tokens=[fr_idx_dict.get(i,\"UNK\") for i in indexes]   \n",
    "print(tokens)\n",
    "phrase=\"\".join(tokens)\n",
    "phrase=phrase.replace(\"</w>\",\" \")\n",
    "for x in '''?:;.,'(\"-!&)%''':\n",
    "    phrase=phrase.replace(f\" {x}\",f\"{x}\")   \n",
    "print(phrase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f77812fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"files/dict.p\",\"wb\") as fb:\n",
    "    pickle.dump((en_word_dict,en_idx_dict,\n",
    "                 fr_word_dict,fr_idx_dict),fb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f214dff",
   "metadata": {},
   "source": [
    "## 2.2. Sequence Padding and Batch Creation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "20480c61",
   "metadata": {},
   "outputs": [],
   "source": [
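    "# convert tokens to integer indexes; unseen tokens map to UNK\n",
    "out_en_ids=[[en_word_dict.get(w,UNK) for w in s] for s in en_tokens]\n",
    "out_fr_ids=[[fr_word_dict.get(w,UNK) for w in s] for s in fr_tokens]\n",
    "# sort pairs by English length so that each batch holds sentences of\n",
    "# similar length, which reduces the amount of padding needed\n",
    "sorted_ids=sorted(range(len(out_en_ids)),\n",
    "                  key=lambda x:len(out_en_ids[x]))\n",
    "out_en_ids=[out_en_ids[x] for x in sorted_ids]\n",
    "out_fr_ids=[out_fr_ids[x] for x in sorted_ids]"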
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "91845c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch_size=128\n",
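    "# starting index of each batch; shuffling these indexes changes the\n",
    "# order of the batches while keeping similar-length sentences together\n",
    "idx_list=np.arange(0,len(en_tokens),batch_size)\n",
    "np.random.shuffle(idx_list)\n",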
    "\n",
    "batch_indexs=[]\n",
    "for idx in idx_list:\n",
    "    batch_indexs.append(np.arange(idx,min(len(en_tokens),\n",
    "                                          idx+batch_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4bec238e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seq_padding(X, padding=0):\n",
    "    # pad every sequence in X with `padding` up to the longest length in X\n",
    "    L = [len(x) for x in X]\n",
    "    ML = max(L)\n",
    "    padded_seq = np.array([np.concatenate([x, [padding] * (ML - len(x))])\n",
    "        if len(x) < ML else x for x in X])\n",
    "    return padded_seq"
   ]
  },
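  {
   "cell_type": "markdown",
   "id": "3f1a9c03",
   "metadata": {},
   "source": [
    "Because the pairs were sorted by English length before being cut into batches of 128, sentences in the same batch have similar lengths, so `seq_padding()` adds little padding. The sketch below illustrates the effect on made-up sequence lengths with a batch size of 2; `pad_batch()` and `count_padding()` are hypothetical helpers, and `pad_batch()` pads the same way `seq_padding()` does (appending 0s up to the longest sequence in the batch).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pad_batch(seqs, padding=0):\n",
    "    # pad each sequence with `padding` up to the longest one in the batch\n",
    "    max_len = max(len(s) for s in seqs)\n",
    "    return np.array([list(s) + [padding] * (max_len - len(s)) for s in seqs])\n",
    "\n",
    "def count_padding(seqs, batch_size):\n",
    "    # total number of padding tokens added when batching seqs in order\n",
    "    total = 0\n",
    "    for start in range(0, len(seqs), batch_size):\n",
    "        batch = pad_batch(seqs[start:start + batch_size])\n",
    "        total += int((batch == 0).sum())\n",
    "    return total\n",
    "\n",
    "# made-up token sequences (ids start at 2, so 0 only appears as padding)\n",
    "lengths = [3, 12, 4, 11, 5, 10, 6, 9]\n",
    "seqs = [list(range(2, 2 + n)) for n in lengths]\n",
    "unsorted_pad = count_padding(seqs, batch_size=2)\n",
    "sorted_pad = count_padding(sorted(seqs, key=len), batch_size=2)\n",
    "print(pad_batch(seqs[:2]))\n",
    "print(unsorted_pad, sorted_pad)\n",
    "```\n",
    "\n",
    "With these lengths, batching in the original order adds 24 padding tokens, while batching after sorting adds only 4."
   ]
  },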
  {
   "cell_type": "markdown",
   "id": "73421610",
   "metadata": {},
   "source": [
    "The following Batch class, together with the helper functions it uses, is defined in the local module utils/ch09util.py:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1f9ac0b",
   "metadata": {},
   "source": [
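    "```python\n",
    "import torch\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# define the Batch class\n",
    "class Batch:\n",
    "    def __init__(self, src, trg=None, pad=0):\n",
    "        src = torch.from_numpy(src).to(DEVICE).long()\n",
    "        self.src = src\n",
    "        # True at real tokens, False at padding; shape (batch, 1, src_len)\n",
    "        self.src_mask = (src != pad).unsqueeze(-2)\n",
    "        if trg is not None:\n",
    "            # convert the target only when one is given\n",
    "            trg = torch.from_numpy(trg).to(DEVICE).long()\n",
    "            # decoder input drops the last position; labels drop BOS\n",
    "            self.trg = trg[:, :-1]\n",
    "            self.trg_y = trg[:, 1:]\n",
    "            # hide padding and future positions from the decoder\n",
    "            self.trg_mask = make_std_mask(self.trg, pad)\n",
    "            # number of real (non-padding) target tokens\n",
    "            self.ntokens = (self.trg_y != pad).data.sum()\n",
    "```"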
   ]
  },
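  {
   "cell_type": "markdown",
   "id": "3f1a9c04",
   "metadata": {},
   "source": [
    "The slicing `trg[:, :-1]` and `trg[:, 1:]` in *Batch* sets up teacher forcing: at every position the decoder reads the tokens up to that point and is trained to predict the next one. Here is a small NumPy sketch with made-up token ids; treating 2 as BOS, 3 as EOS, and 0 as PAD is an assumption for illustration (only PAD=0 matches the chapter; the real BOS and EOS ids come from the dictionaries built earlier).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "PAD = 0\n",
    "# one padded target row: BOS, three word ids, EOS, PAD (ids are made up)\n",
    "trg = np.array([[2, 17, 42, 9, 3, 0]])\n",
    "\n",
    "decoder_input = trg[:, :-1]   # what the decoder reads\n",
    "labels = trg[:, 1:]           # what it must predict at each position\n",
    "ntokens = int((labels != PAD).sum())  # real (non-padding) target tokens\n",
    "\n",
    "for inp, lab in zip(decoder_input[0], labels[0]):\n",
    "    print(f'input {inp:>2} -> predict {lab:>2}')\n",
    "print('ntokens =', ntokens)\n",
    "```"
   ]
  },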
  {
   "cell_type": "markdown",
   "id": "b38389fa",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def subsequent_mask(size):\n",
    "    # ones strictly above the diagonal mark future positions\n",
    "    attn_shape = (1, size, size)\n",
    "    subsequent_mask = np.triu(np.ones(attn_shape),\n",
    "                              k=1).astype('uint8')\n",
    "    # True where position i may attend to position j (j <= i)\n",
    "    output = torch.from_numpy(subsequent_mask) == 0\n",
    "    return output\n",
    "\n",
    "def make_std_mask(tgt, pad):\n",
    "    # combine the padding mask with the no-peeking-ahead mask\n",
    "    tgt_mask = (tgt != pad).unsqueeze(-2)\n",
    "    output = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)\n",
    "    return output\n",
    "```"
   ]
  },
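  {
   "cell_type": "markdown",
   "id": "3f1a9c05",
   "metadata": {},
   "source": [
    "To see what `make_std_mask()` produces, here is the same computation in NumPy for one made-up target row of length 4 whose last position is padding (0). Row *i* of the result shows which positions the decoder may attend to at position *i*.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "PAD = 0\n",
    "tgt = np.array([[2, 17, 42, 0]])  # made-up ids; the last position is padding\n",
    "\n",
    "# same construction as subsequent_mask(): True on and below the diagonal\n",
    "size = tgt.shape[-1]\n",
    "no_peek = np.triu(np.ones((1, size, size)), k=1) == 0\n",
    "\n",
    "# padding mask of shape (batch, 1, size), broadcast over the rows\n",
    "pad_mask = (tgt != PAD)[:, np.newaxis, :]\n",
    "std_mask = pad_mask & no_peek\n",
    "print(std_mask.astype(int)[0])\n",
    "```\n",
    "\n",
    "Entries above the diagonal are 0 because they lie in the future, and the last column is 0 in every row because that position is padding."
   ]
  },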
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6e37251c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import Batch\n",
    "\n",
    "batches=[]\n",
    "for b in batch_indexs:\n",
    "    batch_en=[out_en_ids[x] for x in b]\n",
    "    batch_fr=[out_fr_ids[x] for x in b]\n",
    "    batch_en=seq_padding(batch_en)\n",
    "    batch_fr=seq_padding(batch_fr)\n",
    "    batches.append(Batch(batch_en,batch_fr))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4450f4fe",
   "metadata": {},
   "source": [
    "## 2.3. Word Embedding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "05d352a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "there are 11055 distinct English tokens\n",
      "there are 11239 distinct French tokens\n"
     ]
    }
   ],
   "source": [
    "src_vocab = len(en_word_dict)\n",
    "tgt_vocab = len(fr_word_dict)\n",
    "print(f\"there are {src_vocab} distinct English tokens\")\n",
    "print(f\"there are {tgt_vocab} distinct French tokens\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "139037be",
   "metadata": {},
   "source": [
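    "```python\n",
    "import math\n",
    "from torch import nn\n",
    "\n",
    "class Embeddings(nn.Module):\n",
    "    def __init__(self, d_model, vocab):\n",
    "        super().__init__()\n",
    "        # lookup table: one d_model-dimensional vector per token index\n",
    "        self.lut = nn.Embedding(vocab, d_model)\n",
    "        self.d_model = d_model\n",
    "\n",
    "    def forward(self, x):\n",
    "        # the paper multiplies the embedding weights by sqrt(d_model)\n",
    "        out = self.lut(x) * math.sqrt(self.d_model)\n",
    "        return out\n",
    "```"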
   ]
  },
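  {
   "cell_type": "markdown",
   "id": "3f1a9c06",
   "metadata": {},
   "source": [
    "An embedding layer such as `nn.Embedding` is a trainable lookup table: selecting row *i* gives the same result as multiplying a one-hot vector for token *i* by the weight matrix. The NumPy sketch below checks this with a random 5-token, 4-dimensional table (the sizes are arbitrary) and then applies the sqrt(d_model) scaling used in *Embeddings*.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "vocab, d_model = 5, 4\n",
    "table = rng.normal(size=(vocab, d_model))  # stands in for nn.Embedding weights\n",
    "\n",
    "token_ids = np.array([3, 1, 3])\n",
    "lookup = table[token_ids]                   # row lookup, shape (3, d_model)\n",
    "one_hot = np.eye(vocab)[token_ids]          # shape (3, vocab)\n",
    "matmul = one_hot @ table                    # same result via a matrix product\n",
    "\n",
    "scaled = lookup * math.sqrt(d_model)        # as in Embeddings.forward()\n",
    "print(np.allclose(lookup, matmul), scaled.shape)\n",
    "```"
   ]
  },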
  {
   "cell_type": "markdown",
   "id": "a5f17ecb",
   "metadata": {},
   "source": [
    "## 2.4. Positional Encoding\n",
    "To model the order of elements in the input and output sequences, we'll first create positional encodings of the sequences as follows:"
   ]
  },
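  {
   "cell_type": "markdown",
   "id": "3f1a9c07",
   "metadata": {},
   "source": [
    "For reference, the sinusoidal encodings defined in the 2017 paper, which *PositionalEncoding* below computes, are (with $pos$ the position in the sequence and $i$ indexing pairs of embedding dimensions):\n",
    "\n",
    "$$PE_{(pos,\\,2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right), \\qquad PE_{(pos,\\,2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right)$$\n",
    "\n",
    "In the code, `div_term` holds $10000^{-2i/d_{model}} = \\exp\\left(-2i \\ln(10000) / d_{model}\\right)$ for $2i = 0, 2, 4, \\ldots$, so `pe_pos = torch.mul(position, div_term)` holds the arguments of the sine (even dimensions) and cosine (odd dimensions) for every position at once."
   ]
  },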
  {
   "cell_type": "markdown",
   "id": "4d886689",
   "metadata": {},
   "source": [
    "```python\n",
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model, dropout, max_len=5000):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "        pe = torch.zeros(max_len, d_model, device=DEVICE)\n",
    "        position = torch.arange(0., max_len, \n",
    "                                device=DEVICE).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(\n",
    "            0., d_model, 2, device=DEVICE)\n",
    "            * -(math.log(10000.0) / d_model))\n",
    "        pe_pos = torch.mul(position, div_term)\n",
    "        pe[:, 0::2] = torch.sin(pe_pos)\n",
    "        pe[:, 1::2] = torch.cos(pe_pos)\n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)  \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pe[:, :x.size(1)].requires_grad_(False)\n",
    "        out = self.dropout(x)\n",
    "        return out\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9dffa6e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the shape of positional encoding is torch.Size([1, 8, 256])\n",
      "tensor([[[ 0.0000e+00,  1.1111e+00,  0.0000e+00,  ...,  1.1111e+00,\n",
      "           0.0000e+00,  1.1111e+00],\n",
      "         [ 9.3497e-01,  6.0034e-01,  8.9107e-01,  ...,  1.1111e+00,\n",
      "           1.1940e-04,  1.1111e+00],\n",
      "         [ 1.0103e+00, -4.6239e-01,  1.0646e+00,  ...,  1.1111e+00,\n",
      "           2.3880e-04,  1.1111e+00],\n",
      "         ...,\n",
      "         [-1.0655e+00,  3.1518e-01, -1.1091e+00,  ...,  1.1111e+00,\n",
      "           5.9700e-04,  1.1111e+00],\n",
      "         [-3.1046e-01,  1.0669e+00, -7.1559e-01,  ...,  1.1111e+00,\n",
      "           7.1640e-04,  1.1111e+00],\n",
      "         [ 7.2999e-01,  0.0000e+00,  2.5419e-01,  ...,  1.1111e+00,\n",
      "           8.3581e-04,  1.1111e+00]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "from utils.ch09util import PositionalEncoding\n",
    "import torch\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "pe = PositionalEncoding(256, 0.1)\n",
    "x = torch.zeros(1, 8, 256).to(DEVICE)\n",
    "y = pe.forward(x)\n",
    "print(f\"the shape of positional encoding is {y.shape}\")\n",
    "print(y)"
   ]
  },
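  {
   "cell_type": "markdown",
   "id": "3f1a9c08",
   "metadata": {},
   "source": [
    "The printed values are not the raw sines and cosines. The module is in training mode, so its dropout layer (p=0.1) zeroes roughly 10% of the entries and scales the rest by 1/(1 - 0.1), approximately 1.1111; that is why the cosine columns at position 0 print as 1.1111 and why some entries print as 0. Calling `pe.eval()` before `pe.forward(x)` switches dropout off. The NumPy sketch below recomputes the sinusoid table for the same sizes (d_model=256, 8 positions) and checks a few entries copied from the printout after undoing the 1.1111 scaling.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "d_model, max_len = 256, 8\n",
    "position = np.arange(max_len)[:, np.newaxis]\n",
    "div_term = np.exp(np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))\n",
    "pe = np.zeros((max_len, d_model))\n",
    "pe[:, 0::2] = np.sin(position * div_term)\n",
    "pe[:, 1::2] = np.cos(position * div_term)\n",
    "\n",
    "scale = 1 / (1 - 0.1)  # inverted-dropout scaling applied to kept entries\n",
    "# entries copied from the printed tensor above, divided by the scale\n",
    "print(round(0.93497 / scale, 4), round(pe[1, 0], 4))    # sin(1)\n",
    "print(round(0.60034 / scale, 4), round(pe[1, 1], 4))    # cos(1)\n",
    "print(round(1.0103 / scale, 4), round(pe[2, 0], 4))     # sin(2)\n",
    "```"
   ]
  },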
  {
   "cell_type": "markdown",
   "id": "832f3c43",
   "metadata": {},
   "source": [
    "# 3. Create a Transformer\n",
    "We'll follow the 2017 paper to create and train an encoder-decoder Transformer that translates English into French. The code is adapted from the English-to-Chinese translator by Chris Cui (https://cuicaihao.com/the-annotated-transformer-english-to-chinese-translator/) and the German-to-English translator by Alexander Rush (http://nlp.seas.harvard.edu/annotated-transformer/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd90f9eb",
   "metadata": {},
   "source": [
    "## 3.1. The Attention Mechanism\n",
    "\n",
    "\n",
    "The *attention()* function is defined in the local module as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2ec5222",
   "metadata": {},
   "source": [
    "```python\n",
    "def attention(query, key, value, mask=None, dropout=None):\n",
    "    d_k = query.size(-1)\n",
    "    scores = torch.matmul(query, \n",
    "              key.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "    if mask is not None:\n",
    "        scores = scores.masked_fill(mask == 0, -1e9)\n",
    "    p_attn = nn.functional.softmax(scores, dim=-1)\n",
    "    if dropout is not None:\n",
    "        p_attn = dropout(p_attn)\n",
    "    return torch.matmul(p_attn, value), p_attn\n",
    "```"
   ]
  },
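  {
   "cell_type": "markdown",
   "id": "3f1a9c09",
   "metadata": {},
   "source": [
    "The function implements scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. Masked positions receive a score of -1e9 before the softmax, so their weights become effectively zero. The NumPy sketch below mirrors `attention()` without dropout (`attention_np()` is a hypothetical name) on random tensors with the (batch, heads, seq_len, d_k) layout that *MultiHeadedAttention* passes in; the sizes are arbitrary, and the mask has the (batch, 1, 1, seq_len) shape produced by `mask.unsqueeze(1)`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def attention_np(query, key, value, mask=None):\n",
    "    # NumPy mirror of attention() above, without dropout\n",
    "    d_k = query.shape[-1]\n",
    "    scores = query @ np.swapaxes(key, -2, -1) / math.sqrt(d_k)\n",
    "    if mask is not None:\n",
    "        scores = np.where(mask == 0, -1e9, scores)\n",
    "    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax\n",
    "    p_attn = np.exp(scores)\n",
    "    p_attn = p_attn / p_attn.sum(axis=-1, keepdims=True)\n",
    "    return p_attn @ value, p_attn\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "batch, heads, seq_len, d_k = 2, 8, 5, 32\n",
    "q = rng.normal(size=(batch, heads, seq_len, d_k))\n",
    "k = rng.normal(size=(batch, heads, seq_len, d_k))\n",
    "v = rng.normal(size=(batch, heads, seq_len, d_k))\n",
    "# hide the last key position (e.g. padding) from every query\n",
    "mask = np.ones((batch, 1, 1, seq_len))\n",
    "mask[..., -1] = 0\n",
    "\n",
    "out, p_attn = attention_np(q, k, v, mask)\n",
    "print(out.shape, p_attn.shape)\n",
    "print(p_attn[0, 0].sum(axis=-1))   # each row of weights sums to 1\n",
    "```"
   ]
  },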
  {
   "cell_type": "markdown",
   "id": "09d4bdc7",
   "metadata": {},
   "source": [
    "```python\n",
    "from copy import deepcopy\n",
    "class MultiHeadedAttention(nn.Module):\n",
    "    def __init__(self, h, d_model, dropout=0.1):\n",
    "        super().__init__()\n",
    "        assert d_model % h == 0\n",
    "        self.d_k = d_model // h\n",
    "        self.h = h\n",
    "        self.linears = nn.ModuleList([deepcopy(\n",
    "            nn.Linear(d_model, d_model)) for i in range(4)])\n",
    "        self.attn = None\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "    def forward(self, query, key, value, mask=None):\n",
    "        if mask is not None:\n",
    "            mask = mask.unsqueeze(1)\n",
    "        nbatches = query.size(0)  \n",
    "        query, key, value = [l(x).view(nbatches, -1, self.h,\n",
    "           self.d_k).transpose(1, 2)\n",
    "         for l, x in zip(self.linears, (query, key, value))]\n",
    "        x, self.attn = attention(\n",
    "            query, key, value, mask=mask, dropout=self.dropout)\n",
    "        x = x.transpose(1, 2).contiguous().view(\n",
    "            nbatches, -1, self.h * self.d_k)\n",
    "        output = self.linears[-1](x)\n",
    "        return output \n",
    "```"
   ]
  },
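  {
   "cell_type": "markdown",
   "id": "3f1a9c10",
   "metadata": {},
   "source": [
    "Multi-head attention does not add width: with d_model=256 and h=8, each head works on d_k = 256/8 = 32 dimensions. The `view(...).transpose(1, 2)` step splits the last dimension into heads, and `transpose(1, 2).contiguous().view(...)` merges them back. Here is the same reshaping in NumPy, with sizes chosen to match the model and random data, showing that the round trip is lossless and that head 0 simply sees the first 32 features.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "nbatches, seq_len, d_model, h = 2, 7, 256, 8\n",
    "d_k = d_model // h\n",
    "x = np.random.default_rng(0).normal(size=(nbatches, seq_len, d_model))\n",
    "\n",
    "# split: (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)\n",
    "heads = x.reshape(nbatches, -1, h, d_k).transpose(0, 2, 1, 3)\n",
    "\n",
    "# merge: (batch, h, seq, d_k) -> (batch, seq, h, d_k) -> (batch, seq, d_model)\n",
    "merged = heads.transpose(0, 2, 1, 3).reshape(nbatches, -1, h * d_k)\n",
    "\n",
    "print(heads.shape, merged.shape, np.array_equal(x, merged))\n",
    "```"
   ]
  },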
  {
   "cell_type": "markdown",
   "id": "ca029dbb",
   "metadata": {},
   "source": [
    "```python\n",
    "class PositionwiseFeedForward(nn.Module):\n",
    "    def __init__(self, d_model, d_ff, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.w_1 = nn.Linear(d_model, d_ff)\n",
    "        self.w_2 = nn.Linear(d_ff, d_model)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # FFN(x) = max(0, x W1 + b1) W2 + b2, with dropout after the ReLU\n",
    "        h1 = nn.functional.relu(self.w_1(x))\n",
    "        h2 = self.dropout(h1)\n",
    "        return self.w_2(h2)\n",
    "```"
   ]
  },
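  {
   "cell_type": "markdown",
   "id": "3f1a9c11",
   "metadata": {},
   "source": [
    "The 2017 paper defines the position-wise feed-forward network as FFN(x) = max(0, xW1 + b1)W2 + b2; the ReLU between the two linear layers is what makes it non-linear. Without an activation, two stacked linear layers collapse into a single linear map, as this NumPy check with small random weights (arbitrary sizes) shows.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "d_model, d_ff = 4, 16\n",
    "W1, b1 = rng.normal(size=(d_model, d_ff)), rng.normal(size=d_ff)\n",
    "W2, b2 = rng.normal(size=(d_ff, d_model)), rng.normal(size=d_model)\n",
    "x = rng.normal(size=(3, d_model))\n",
    "\n",
    "# two linear layers with no activation in between\n",
    "no_act = (x @ W1 + b1) @ W2 + b2\n",
    "# equivalent single layer: weight W1 @ W2, bias b1 @ W2 + b2\n",
    "collapsed = x @ (W1 @ W2) + (b1 @ W2 + b2)\n",
    "# the paper's version with a ReLU in between\n",
    "with_relu = np.maximum(0, x @ W1 + b1) @ W2 + b2\n",
    "\n",
    "print(np.allclose(no_act, collapsed), np.allclose(with_relu, collapsed))\n",
    "```"
   ]
  },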
  {
   "cell_type": "markdown",
   "id": "7a568725",
   "metadata": {},
   "source": [
    "## 3.2. Create an encoder-decoder Transformer\n",
    "To create an encoder-decoder transformer, we define a Transformer class in the local module *ch09util.py* as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f9204e",
   "metadata": {},
   "source": [
    "```python\n",
    "# An encoder-decoder transformer\n",
    "class Transformer(nn.Module):\n",
    "    def __init__(self, encoder, decoder,\n",
    "                 src_embed, tgt_embed, generator):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.src_embed = src_embed\n",
    "        self.tgt_embed = tgt_embed\n",
    "        self.generator = generator\n",
    "\n",
    "    def encode(self, src, src_mask):\n",
    "        return self.encoder(self.src_embed(src), src_mask)\n",
    "\n",
    "    def decode(self, memory, src_mask, tgt, tgt_mask):\n",
    "        return self.decoder(self.tgt_embed(tgt), \n",
    "                            memory, src_mask, tgt_mask)\n",
    "\n",
    "    def forward(self, src, tgt, src_mask, tgt_mask):\n",
    "        memory = self.encode(src, src_mask)\n",
    "        output = self.decode(memory, src_mask, tgt, tgt_mask)\n",
    "        return output\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "112ae1cc",
   "metadata": {},
   "source": [
    "```python\n",
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self, size, self_attn, feed_forward, dropout):\n",
    "        super().__init__()\n",
    "        self.self_attn = self_attn\n",
    "        self.feed_forward = feed_forward\n",
    "        self.sublayer = nn.ModuleList([deepcopy(\n",
    "        SublayerConnection(size, dropout)) for i in range(2)])\n",
    "        self.size = size  \n",
    "    def forward(self, x, mask):\n",
    "        x = self.sublayer[0](\n",
    "            x, lambda x: self.self_attn(x, x, x, mask))\n",
    "        output = self.sublayer[1](x, self.feed_forward)\n",
    "        return output \n",
    "    \n",
    "class SublayerConnection(nn.Module):\n",
    "    def __init__(self, size, dropout):\n",
    "        super().__init__()\n",
    "        self.norm = LayerNorm(size)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "    def forward(self, x, sublayer):\n",
    "        output = x + self.dropout(sublayer(self.norm(x)))\n",
    "        return output  \n",
    "```"
   ]
  },
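  {
   "cell_type": "markdown",
   "id": "3f1a9c12",
   "metadata": {},
   "source": [
    "Note that *SublayerConnection* normalizes the input *before* the sublayer, x + Dropout(Sublayer(LayerNorm(x))), a *pre-norm* arrangement that the Harvard annotated Transformer also uses, whereas the 2017 paper applies LayerNorm(x + Sublayer(x)) after the residual addition. In the pre-norm arrangement the residual stream itself is never normalized, which is why *Encoder* and *Decoder* apply one more LayerNorm after the last layer. The NumPy sketch below contrasts the two orderings; the stand-in sublayer (a fixed random linear map) and the layer norm without learnable scale and shift are simplifications for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def layer_norm(x, eps=1e-6):\n",
    "    # normalize each row to mean 0 and (roughly) unit variance\n",
    "    mean = x.mean(axis=-1, keepdims=True)\n",
    "    var = x.var(axis=-1, keepdims=True)\n",
    "    return (x - mean) / np.sqrt(var + eps)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.normal(size=(8, 8))\n",
    "\n",
    "def sublayer(z):\n",
    "    # stand-in for attention or the feed-forward network\n",
    "    return z @ W\n",
    "\n",
    "x = rng.normal(loc=3.0, size=(4, 8))\n",
    "pre_norm = x + sublayer(layer_norm(x))    # as in SublayerConnection\n",
    "post_norm = layer_norm(x + sublayer(x))   # as in the 2017 paper\n",
    "\n",
    "print(np.round(pre_norm.mean(axis=-1), 2))\n",
    "print(np.round(post_norm.mean(axis=-1), 2))\n",
    "```\n",
    "\n",
    "The post-norm rows always have mean 0, while the pre-norm output still carries the un-normalized input x."
   ]
  },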
  {
   "cell_type": "markdown",
   "id": "9c0f149c",
   "metadata": {},
   "source": [
    "```python\n",
    "class LayerNorm(nn.Module):\n",
    "    def __init__(self, features, eps=1e-6):\n",
    "        super().__init__()\n",
    "        self.a_2 = nn.Parameter(torch.ones(features))\n",
    "        self.b_2 = nn.Parameter(torch.zeros(features))\n",
    "        self.eps = eps\n",
    "    def forward(self, x):\n",
    "        mean = x.mean(-1, keepdim=True) \n",
    "        std = x.std(-1, keepdim=True)\n",
    "        x_zscore = (x - mean) / torch.sqrt(std ** 2 + self.eps)\n",
    "        output = self.a_2*x_zscore+self.b_2\n",
    "        return output\n",
    "```"
   ]
  },
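  {
   "cell_type": "markdown",
   "id": "3f1a9c13",
   "metadata": {},
   "source": [
    "One detail of this *LayerNorm*: PyTorch's `x.std()` uses the unbiased (n - 1) estimator by default, whereas the built-in `nn.LayerNorm` uses the biased (n) variance, so the two normalizations differ slightly. NumPy's `ddof` argument makes the difference easy to see on a made-up row of four features.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "x = np.array([[1.0, 2.0, 3.0, 4.0]])\n",
    "eps = 1e-6\n",
    "mean = x.mean(axis=-1, keepdims=True)\n",
    "\n",
    "# unbiased std (divide by n-1), like torch's default x.std() used above\n",
    "std_unbiased = x.std(axis=-1, ddof=1, keepdims=True)\n",
    "# biased std (divide by n), as used by torch.nn.LayerNorm\n",
    "std_biased = x.std(axis=-1, ddof=0, keepdims=True)\n",
    "\n",
    "z_custom = (x - mean) / np.sqrt(std_unbiased ** 2 + eps)\n",
    "z_biased = (x - mean) / np.sqrt(std_biased ** 2 + eps)\n",
    "print(np.round(z_custom, 4))\n",
    "print(np.round(z_biased, 4))\n",
    "```\n",
    "\n",
    "Ignoring eps, the two differ by a constant factor of sqrt((n - 1)/n); with n = d_model = 256 that factor is about 0.998, and the learnable a_2 can largely absorb it."
   ]
  },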
  {
   "cell_type": "markdown",
   "id": "89828266",
   "metadata": {},
   "source": [
    "The encoder stacks N=6 identical encoder layers; because each layer is created with deepcopy, every layer has its own weights. The *Encoder* class is defined as follows in the local module:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e76d313b",
   "metadata": {},
   "source": [
    "```python\n",
    "# Create an encoder\n",
    "from copy import deepcopy\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, layer, N):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList(\n",
    "            [deepcopy(layer) for i in range(N)])\n",
    "        self.norm = LayerNorm(layer.size)\n",
    "\n",
    "    def forward(self, x, mask):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, mask)\n",
    "        # apply the final LayerNorm once, after the last layer\n",
    "        output = self.norm(x)\n",
    "        return output\n",
    "```"
   ]
  },
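  {
   "cell_type": "markdown",
   "id": "3f1a9c14",
   "metadata": {},
   "source": [
    "Using `deepcopy` matters here: it creates six independent layers, whereas repeating the same object (for example `[layer] * 6`) would make all six entries share one set of weights. The sketch below shows the difference with a hypothetical `ToyLayer` class that stands in for *EncoderLayer* and holds a single NumPy weight matrix.\n",
    "\n",
    "```python\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "\n",
    "class ToyLayer:\n",
    "    # hypothetical stand-in for EncoderLayer, holding one weight matrix\n",
    "    def __init__(self):\n",
    "        self.weight = np.zeros((2, 2))\n",
    "\n",
    "layer = ToyLayer()\n",
    "copies = [deepcopy(layer) for _ in range(6)]   # as in Encoder\n",
    "shared = [layer] * 6                           # six references to ONE object\n",
    "\n",
    "copies[0].weight += 1.0   # update only the first entry of each list\n",
    "shared[0].weight += 1.0\n",
    "\n",
    "copies_sums = [float(lyr.weight.sum()) for lyr in copies]\n",
    "shared_sums = [float(lyr.weight.sum()) for lyr in shared]\n",
    "print(copies_sums)\n",
    "print(shared_sums)\n",
    "```"
   ]
  },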
  {
   "cell_type": "markdown",
   "id": "85c45299",
   "metadata": {},
   "source": [
    "```python\n",
    "# Create a decoder\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, layer, N):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList(\n",
    "            [deepcopy(layer) for i in range(N)])\n",
    "        self.norm = LayerNorm(layer.size)\n",
    "\n",
    "    def forward(self, x, memory, src_mask, tgt_mask):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, memory, src_mask, tgt_mask)\n",
    "        output = self.norm(x)\n",
    "        return output\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd150da1",
   "metadata": {},
   "source": [
    "```python\n",
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self, size, self_attn, src_attn,\n",
    "                 feed_forward, dropout):\n",
    "        super().__init__()\n",
    "        self.size = size\n",
    "        self.self_attn = self_attn\n",
    "        self.src_attn = src_attn\n",
    "        self.feed_forward = feed_forward\n",
    "        self.sublayer = nn.ModuleList([deepcopy(\n",
    "        SublayerConnection(size, dropout)) for i in range(3)])\n",
    "\n",
    "    def forward(self, x, memory, src_mask, tgt_mask):\n",
    "        x = self.sublayer[0](x, lambda x: \n",
    "                 self.self_attn(x, x, x, tgt_mask))\n",
    "        x = self.sublayer[1](x, lambda x:\n",
    "                 self.src_attn(x, memory, memory, src_mask))\n",
    "        output = self.sublayer[2](x, self.feed_forward)\n",
    "        return output \n",
    "```"
   ]
  },
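  {
   "cell_type": "markdown",
   "id": "3f1a9c15",
   "metadata": {},
   "source": [
    "In the second sublayer (`src_attn`), the queries come from the decoder while the keys and values come from the encoder output `memory`. The attention-weight matrix is therefore (target length x source length): each French position gets a distribution over the English positions. The NumPy sketch below shows only these shapes, using random data, arbitrary lengths, a single head, and no linear projections, all simplifications for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "tgt_len, src_len, d_model = 4, 6, 16\n",
    "x = rng.normal(size=(tgt_len, d_model))        # decoder states (queries)\n",
    "memory = rng.normal(size=(src_len, d_model))   # encoder output (keys, values)\n",
    "src_mask = np.array([1, 1, 1, 1, 0, 0])        # last two source tokens are padding\n",
    "\n",
    "scores = x @ memory.T / math.sqrt(d_model)     # shape (tgt_len, src_len)\n",
    "scores = np.where(src_mask == 0, -1e9, scores)\n",
    "weights = np.exp(scores - scores.max(axis=-1, keepdims=True))\n",
    "weights = weights / weights.sum(axis=-1, keepdims=True)\n",
    "context = weights @ memory                     # shape (tgt_len, d_model)\n",
    "print(weights.shape, context.shape)\n",
    "```"
   ]
  },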
  {
   "cell_type": "markdown",
   "id": "11199cf5",
   "metadata": {},
   "source": [
    "## 3.3. Put All the Pieces Together\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "090eafd8",
   "metadata": {},
   "source": [
    "```python\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, d_model, vocab):\n",
    "        super().__init__()\n",
    "        self.proj = nn.Linear(d_model, vocab)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.proj(x)\n",
    "        probs = nn.functional.log_softmax(out, dim=-1)\n",
    "        return probs  \n",
    "```"
   ]
  },
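  {
   "cell_type": "markdown",
   "id": "3f1a9c16",
   "metadata": {},
   "source": [
    "*Generator* turns each decoder output vector into log-probabilities over the French vocabulary. At inference time the translation is produced autoregressively: start from BOS, choose the next token (for example, greedily, the one with the highest log-probability), append it, and repeat until EOS or a length limit. The sketch below shows only that loop; `toy_log_probs()` is a hypothetical stand-in for running the decoder plus *Generator* on the tokens produced so far, and the token ids and vocabulary size are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "BOS, EOS = 2, 3\n",
    "VOCAB = 6\n",
    "\n",
    "def toy_log_probs(tokens):\n",
    "    # hypothetical stand-in for decoder + Generator: scripted next tokens\n",
    "    script = {1: 4, 2: 5, 3: EOS}\n",
    "    logits = np.zeros(VOCAB)\n",
    "    logits[script.get(len(tokens), EOS)] = 5.0\n",
    "    # log-softmax, as in Generator.forward()\n",
    "    return logits - np.log(np.exp(logits).sum())\n",
    "\n",
    "def greedy_decode(max_len=10):\n",
    "    tokens = [BOS]\n",
    "    for _ in range(max_len):\n",
    "        # pick the most likely next token and feed it back in\n",
    "        next_token = int(np.argmax(toy_log_probs(tokens)))\n",
    "        tokens.append(next_token)\n",
    "        if next_token == EOS:\n",
    "            break\n",
    "    return tokens\n",
    "\n",
    "print(greedy_decode())\n",
    "```"
   ]
  },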
  {
   "cell_type": "markdown",
   "id": "270bfb7f",
   "metadata": {},
   "source": [
    "```python\n",
    "# create the model\n",
    "def create_model(src_vocab, tgt_vocab, N, d_model,\n",
    "                 d_ff, h, dropout=0.1):\n",
    "    attn=MultiHeadedAttention(h, d_model, dropout).to(DEVICE)\n",
    "    ff=PositionwiseFeedForward(d_model, d_ff, dropout).to(DEVICE)\n",
    "    pos=PositionalEncoding(d_model, dropout).to(DEVICE)\n",
    "    model = Transformer(\n",
    "        Encoder(EncoderLayer(d_model,deepcopy(attn),deepcopy(ff),\n",
    "                             dropout).to(DEVICE),N).to(DEVICE),\n",
    "        Decoder(DecoderLayer(d_model,deepcopy(attn),\n",
    "             deepcopy(attn),deepcopy(ff), dropout).to(DEVICE),\n",
    "                N).to(DEVICE),\n",
    "        nn.Sequential(Embeddings(d_model, src_vocab).to(DEVICE),\n",
    "                      deepcopy(pos)),\n",
    "        nn.Sequential(Embeddings(d_model, tgt_vocab).to(DEVICE),\n",
    "                      deepcopy(pos)),\n",
    "        Generator(d_model, tgt_vocab)).to(DEVICE)\n",
    "    for p in model.parameters():\n",
    "        if p.dim() > 1:\n",
    "            nn.init.xavier_uniform_(p)\n",
    "    return model.to(DEVICE)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "22e51429",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import create_model\n",
    "\n",
    "model = create_model(src_vocab, tgt_vocab, N=6,\n",
    "    d_model=256, d_ff=1024, h=8, dropout=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0df1b8ba",
   "metadata": {},
   "source": [
    "# 4. Train the Transformer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4768df6",
   "metadata": {},
   "source": [
    "## 4.1 Loss Function and Optimizer\n",
    "\n",
    "\n",
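    "We define the following classes in the local module: *LabelSmoothing*, which computes the loss against smoothed target distributions, and *NoamOpt*, which wraps the optimizer with a warm-up learning-rate schedule."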
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77c8742a",
   "metadata": {},
   "source": [
    "```python\n",
    "class LabelSmoothing(nn.Module):\n",
    "    def __init__(self, size, padding_idx, smoothing=0.1):\n",
    "        super().__init__()\n",
    "        # KL divergence between the predicted log-probabilities\n",
    "        # and the smoothed target distribution, summed over tokens\n",
    "        self.criterion = nn.KLDivLoss(reduction='sum')\n",
    "        self.padding_idx = padding_idx\n",
    "        self.confidence = 1.0 - smoothing\n",
    "        self.smoothing = smoothing\n",
    "        self.size = size\n",
    "        self.true_dist = None\n",
    "\n",
    "    def forward(self, x, target):\n",
    "        assert x.size(1) == self.size\n",
    "        true_dist = x.data.clone()\n",
    "        # spread the smoothing mass over every token except the\n",
    "        # correct one and the padding token (hence size - 2)\n",
    "        true_dist.fill_(self.smoothing / (self.size - 2))\n",
    "        # put the remaining probability (confidence) on the correct token\n",
    "        true_dist.scatter_(1,\n",
    "               target.data.unsqueeze(1), self.confidence)\n",
    "        # the padding token never receives any probability\n",
    "        true_dist[:, self.padding_idx] = 0\n",
    "        # zero out entire rows whose target is padding\n",
    "        mask = torch.nonzero(target.data == self.padding_idx)\n",
    "        if mask.dim() > 0:\n",
    "            true_dist.index_fill_(0, mask.squeeze(), 0.0)\n",
    "        self.true_dist = true_dist\n",
    "        output = self.criterion(x, true_dist.clone().detach())\n",
    "        return output\n",
    "```"
   ]
  },
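  {
   "cell_type": "markdown",
   "id": "sketch-label-smoothing",
   "metadata": {},
   "source": [
    "To see what target distribution `LabelSmoothing` builds, the sketch below repeats its construction steps on a hypothetical five-token vocabulary in which index 0 is padding. With `smoothing=0.1`, the correct token receives 0.9 and each of the other three non-padding tokens receives 0.1 / 3, or about 0.033; rows whose target is padding are zeroed out:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# hypothetical toy setup: 5-token vocabulary, index 0 is padding\n",
    "size, padding_idx, smoothing = 5, 0, 0.1\n",
    "confidence = 1.0 - smoothing\n",
    "target = torch.tensor([2, 0, 4])  # the second target is padding\n",
    "\n",
    "# same steps as LabelSmoothing.forward()\n",
    "true_dist = torch.full((len(target), size), smoothing / (size - 2))\n",
    "true_dist.scatter_(1, target.unsqueeze(1), confidence)  # correct tokens\n",
    "true_dist[:, padding_idx] = 0                           # padding column\n",
    "true_dist[target == padding_idx] = 0                    # padding rows\n",
    "print(true_dist)\n",
    "print(true_dist.sum(dim=1))  # each non-padding row sums to 1\n",
    "```"
   ]
  },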
  {
   "cell_type": "markdown",
   "id": "61cd3b51",
   "metadata": {},
   "source": [
    "```python\n",
    "class NoamOpt:\n",
    "    def __init__(self, model_size, factor, warmup, optimizer):\n",
    "        self.optimizer = optimizer\n",
    "        self._step = 0\n",
    "        self.warmup = warmup\n",
    "        self.factor = factor\n",
    "        self.model_size = model_size\n",
    "        self._rate = 0\n",
    "\n",
    "    def step(self):\n",
    "        # advance the step counter first so rate() never sees step 0\n",
    "        self._step += 1\n",
    "        rate = self.rate()\n",
    "        # set the new learning rate on every parameter group\n",
    "        for p in self.optimizer.param_groups:\n",
    "            p['lr'] = rate\n",
    "        self._rate = rate\n",
    "        self.optimizer.step()\n",
    "\n",
    "    def rate(self, step=None):\n",
    "        if step is None:\n",
    "            step = self._step\n",
    "        # increase linearly for the first `warmup` steps, then\n",
    "        # decay in proportion to the inverse square root of step\n",
    "        output = self.factor * (self.model_size ** (-0.5) *\n",
    "            min(step ** (-0.5), step * self.warmup ** (-1.5)))\n",
    "        return output\n",
    "```"
   ]
  },
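  {
   "cell_type": "markdown",
   "id": "sketch-noam-rate",
   "metadata": {},
   "source": [
    "Plugging numbers into `rate()` shows the shape of the schedule. The sketch below is a standalone copy of the same formula, evaluated with the settings used later to create the optimizer (model size 256, factor 1, 2,000 warm-up steps). The learning rate grows linearly, peaks at step 2,000 at about 0.0014, and then decays in proportion to 1/sqrt(step):\n",
    "\n",
    "```python\n",
    "def noam_rate(step, model_size=256, factor=1.0, warmup=2000):\n",
    "    # standalone copy of NoamOpt.rate(); step must be >= 1,\n",
    "    # since 0 ** (-0.5) raises ZeroDivisionError in Python\n",
    "    return factor * (model_size ** (-0.5) *\n",
    "                     min(step ** (-0.5), step * warmup ** (-1.5)))\n",
    "\n",
    "for step in [1, 500, 1000, 2000, 4000, 8000]:\n",
    "    print(f'step {step:>5}: lr = {noam_rate(step):.6f}')\n",
    "\n",
    "# the peak is reached exactly at the end of warm-up\n",
    "peak_step = max(range(1, 10001), key=noam_rate)\n",
    "print(peak_step)  # 2000\n",
    "```"
   ]
  },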
  {
   "cell_type": "markdown",
   "id": "a073dab3",
   "metadata": {},
   "source": [
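    "We wrap an Adam optimizer in `NoamOpt`, passing a model size of 256 (matching `d_model`), a factor of 1, and 2,000 warm-up steps. Adam's own learning rate is set to 0 because `NoamOpt` overwrites it before every update:"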
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "29a4ab8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import NoamOpt\n",
    "\n",
    "optimizer = NoamOpt(256, 1, 2000, torch.optim.Adam(\n",
    "    model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))"
   ]
  },
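  {
   "cell_type": "markdown",
   "id": "sketch-lambdalr-noam",
   "metadata": {},
   "source": [
    "`NoamOpt` is a small hand-written wrapper. If you prefer PyTorch's built-in schedulers, roughly the same schedule can be expressed with `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by a user-supplied function of the step count. The sketch below is an alternative, not what *ch09util* uses; it sets Adam's base rate to 1.0 so that the effective rate equals the Noam formula, and a hypothetical `nn.Linear` toy model stands in for the Transformer:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "def noam_lambda(epoch, model_size=256, factor=1.0, warmup=2000):\n",
    "    # LambdaLR counts scheduler steps from 0, while NoamOpt's first\n",
    "    # update uses step 1, so shift by one\n",
    "    step = epoch + 1\n",
    "    return factor * (model_size ** (-0.5) *\n",
    "                     min(step ** (-0.5), step * warmup ** (-1.5)))\n",
    "\n",
    "toy = nn.Linear(4, 2)  # hypothetical stand-in for the Transformer\n",
    "# base lr = 1.0, so the effective lr is exactly noam_lambda(step)\n",
    "adam = torch.optim.Adam(toy.parameters(), lr=1.0,\n",
    "                        betas=(0.9, 0.98), eps=1e-9)\n",
    "scheduler = torch.optim.lr_scheduler.LambdaLR(adam, lr_lambda=noam_lambda)\n",
    "\n",
    "lrs = []\n",
    "for _ in range(5):\n",
    "    lrs.append(adam.param_groups[0]['lr'])  # lr used for this update\n",
    "    loss = toy(torch.randn(8, 4)).pow(2).mean()\n",
    "    loss.backward()\n",
    "    adam.step()\n",
    "    adam.zero_grad()\n",
    "    scheduler.step()  # move to the next step of the schedule\n",
    "print(lrs)\n",
    "```\n",
    "\n",
    "One design difference: `NoamOpt` sets the rate inside its own `step()`, whereas with `LambdaLR` you call `scheduler.step()` after each `optimizer.step()`."
   ]
  },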
  {
   "cell_type": "markdown",
   "id": "9b2434c4",
   "metadata": {},
   "source": [
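    "To compute the loss for each batch, we first define the following class in the local module. It applies the generator to the decoder output, divides the loss by the number of tokens, runs backpropagation, and, when an optimizer is given, updates the parameters:"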
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4602ce37",
   "metadata": {},
   "source": [
    "```python\n",
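    "class SimpleLossCompute:\n",
    "    def __init__(self, generator, criterion, opt=None):\n",
    "        self.generator = generator\n",
    "        self.criterion = criterion\n",
    "        self.opt = opt\n",
    "\n",
    "    def __call__(self, x, y, norm):\n",
    "        # project decoder outputs onto the target vocabulary\n",
    "        x = self.generator(x)\n",
    "        # flatten to (tokens, vocab) and (tokens,), then\n",
    "        # normalize the summed loss by the number of tokens\n",
    "        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),\n",
    "                              y.contiguous().view(-1)) / norm\n",
    "        loss.backward()\n",
    "        # update the parameters (and learning rate) if an optimizer is given\n",
    "        if self.opt is not None:\n",
    "            self.opt.step()\n",
    "            self.opt.optimizer.zero_grad()\n",
    "        # return the summed (un-normalized) loss for this batch\n",
    "        return loss.data.item() * norm.float()\n",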
    "```"
   ]
  },
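  {
   "cell_type": "markdown",
   "id": "sketch-loss-flatten",
   "metadata": {},
   "source": [
    "The reshaping in `SimpleLossCompute.__call__()` turns an output of shape (batch, sequence length, vocabulary size) into one row per target position, and dividing by `norm` turns a summed loss into a per-token loss. The sketch below traces those steps on hypothetical toy tensors with random log-probabilities. As a stand-in criterion it uses `nn.NLLLoss(ignore_index=0, reduction='sum')`, which coincides with `LabelSmoothing` only when the smoothing value is 0:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "batch, seq_len, vocab = 2, 3, 5  # hypothetical toy sizes\n",
    "# random log-probabilities standing in for the generator output\n",
    "x = torch.log_softmax(torch.randn(batch, seq_len, vocab), dim=-1)\n",
    "y = torch.tensor([[1, 2, 0],\n",
    "                  [3, 0, 0]])  # 0 is the padding index\n",
    "\n",
    "flat_x = x.contiguous().view(-1, x.size(-1))  # (6, 5): one row per position\n",
    "flat_y = y.contiguous().view(-1)              # (6,)\n",
    "norm = (y != 0).sum()                         # 3 real (non-padding) tokens\n",
    "\n",
    "criterion = nn.NLLLoss(ignore_index=0, reduction='sum')  # stand-in criterion\n",
    "loss = criterion(flat_x, flat_y) / norm  # average loss per real token\n",
    "print(flat_x.shape, flat_y.shape, norm.item(), loss.item())\n",
    "```"
   ]
  },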
  {
   "cell_type": "markdown",
   "id": "f5163f7f",
   "metadata": {},
   "source": [
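    "We then create the label-smoothing criterion, using 0 as the padding index and a smoothing value of 0.1, and combine it with the generator and the optimizer:"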
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5b0f9cc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.ch09util import (LabelSmoothing,\n",
    "       SimpleLossCompute)\n",
    "\n",
    "criterion = LabelSmoothing(tgt_vocab, \n",
    "                           padding_idx=0, smoothing=0.1)\n",
    "loss_func = SimpleLossCompute(\n",
    "            model.generator, criterion, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062599ec",
   "metadata": {},
   "source": [
    "## 4.2 The Training Loop\n",
    "\n",
    "We train the model for 100 epochs. For each batch, we record the loss and the number of tokens. At the end of each epoch, we compute the average loss as the total loss divided by the total number of tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0146ee51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train for 100 epochs\n",
    "for epoch in range(100):\n",
    "    model.train()\n",
    "    tloss=0\n",
    "    tokens=0\n",
    "    for batch in batches:\n",
    "        out = model(batch.src, batch.trg, \n",
    "                    batch.src_mask, batch.trg_mask)\n",
    "        # loss_func also runs backpropagation and the optimizer step\n",
    "        loss = loss_func(out, batch.trg_y, batch.ntokens)\n",
    "        tloss += loss\n",
    "        tokens += batch.ntokens\n",
    "    print(f\"Epoch {epoch}, average loss: {tloss/tokens}\")\n",
    "torch.save(model.state_dict(),\"files/en2fr.pth\")   "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3274397",
   "metadata": {},
   "source": [
    "Training takes a couple of hours on a GPU and may take several hours on a CPU. When it finishes, the model weights are saved to *files/en2fr.pth* on your computer."
   ]
  },
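  {
   "cell_type": "markdown",
   "id": "sketch-token-average",
   "metadata": {},
   "source": [
    "Because `loss_func()` returns each batch's summed loss (the per-token loss multiplied back by `norm`), dividing the epoch total by the total token count weights every token equally. A quick arithmetic check with two hypothetical batches shows how this differs from averaging the per-batch averages:\n",
    "\n",
    "```python\n",
    "# hypothetical per-batch results: (summed loss, number of tokens)\n",
    "batch_results = [(120.0, 40), (30.0, 20)]\n",
    "\n",
    "tloss = sum(loss for loss, _ in batch_results)  # 150.0\n",
    "tokens = sum(n for _, n in batch_results)       # 60\n",
    "print(tloss / tokens)  # 2.5: average loss per token\n",
    "\n",
    "# averaging the per-batch averages instead over-weights the small batch\n",
    "mean_of_means = sum(loss / n for loss, n in batch_results) / len(batch_results)\n",
    "print(mean_of_means)  # 2.25\n",
    "```"
   ]
  },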
  {
   "cell_type": "markdown",
   "id": "8f83ef71",
   "metadata": {},
   "source": [
    "## 4.3 Translate English to French with the Trained Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "a4779970",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(eng):\n",
    "    # tokenize the English sentence\n",
    "    tokenized_en=tokenizer.tokenize(eng)\n",
    "    # add beginning and end tokens\n",
    "    tokenized_en=[\"BOS\"]+tokenized_en+[\"EOS\"]\n",
    "    # convert tokens to indexes\n",
    "    enidx=[en_word_dict.get(i,UNK) for i in tokenized_en]  \n",
    "    src=torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)\n",
    "    # create mask to hide padding\n",
    "    src_mask=(src!=0).unsqueeze(-2)\n",
    "    # encode the English sentence\n",
    "    memory=model.encode(src,src_mask)\n",
    "    # start translation in an autoregressive fashion\n",
    "    start_symbol=fr_word_dict[\"BOS\"]\n",
    "    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)\n",
    "    translation=[]\n",
    "    # generate at most 100 tokens\n",
    "    for _ in range(100):\n",
    "        # decode with a causal mask so each position sees only earlier tokens\n",
    "        out = model.decode(memory,src_mask,ys,\n",
    "            subsequent_mask(ys.size(1)).type_as(src.data))\n",
    "        prob = model.generator(out[:, -1])\n",
    "        # greedy decoding: pick the most likely next token\n",
    "        _, next_word = torch.max(prob, dim=1)\n",
    "        next_word = next_word.data[0]\n",
    "        # append the prediction to the decoder input for the next step\n",
    "        ys = torch.cat([ys, torch.ones(1, 1).type_as(\n",
    "            src.data).fill_(next_word)], dim=1)\n",
    "        sym = fr_idx_dict[ys[0, -1].item()]\n",
    "        # stop once the model emits the end-of-sentence token\n",
    "        if sym != 'EOS':\n",
    "            translation.append(sym)\n",
    "        else:\n",
    "            break\n",
    "    # join subword tokens into a sentence; '</w>' marks the end of a word\n",
    "    trans=\"\".join(translation)\n",
    "    trans=trans.replace(\"</w>\",\" \")\n",
    "    # remove the space before punctuation marks\n",
    "    for x in '''?:;.,'(\"-!&)%''':\n",
    "        trans=trans.replace(f\" {x}\",f\"{x}\")\n",
    "    print(trans)\n",
    "    return trans"
   ]
  },
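  {
   "cell_type": "markdown",
   "id": "sketch-causal-mask",
   "metadata": {},
   "source": [
    "Inside the decoding loop, `subsequent_mask()` (imported from *ch09util* before `translate()` is called) keeps each target position from attending to later positions. Its source is not shown here, so the following is a hypothetical stand-in, `causal_mask()`, built the usual way with `torch.triu`. It assumes `True` marks positions that may be attended to, consistent with `src_mask=(src!=0)` in `translate()`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def causal_mask(size):\n",
    "    # hypothetical stand-in for subsequent_mask(); True = may attend\n",
    "    upper = torch.triu(torch.ones(1, size, size), diagonal=1)\n",
    "    return upper == 0\n",
    "\n",
    "mask = causal_mask(4)\n",
    "print(mask.shape)  # torch.Size([1, 4, 4])\n",
    "# row i allows positions 0..i, so step i never sees future tokens\n",
    "print(mask[0].int())\n",
    "```"
   ]
  },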
  {
   "cell_type": "markdown",
   "id": "a3e8c99d",
   "metadata": {},
   "source": [
    "We load the saved vocabulary dictionaries and the trained weights, switch the model to evaluation mode, and try the function on the English phrase \"Today is a beautiful day!\":"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5a2af177",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aujourd'hui est une belle journee! \n"
     ]
    }
   ],
   "source": [
    "from utils.ch09util import subsequent_mask\n",
    "\n",
    "with open(\"files/dict.p\",\"rb\") as fb:\n",
    "    en_word_dict,en_idx_dict,\\\n",
    "    fr_word_dict,fr_idx_dict=pickle.load(fb)\n",
    "trained_weights=torch.load(\"files/en2fr.pth\",\n",
    "                           map_location=DEVICE)\n",
    "model.load_state_dict(trained_weights)\n",
    "model.eval()\n",
    "eng = \"Today is a beautiful day!\"\n",
    "translated_fr = translate(eng)"
   ]
  },
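  {
   "cell_type": "markdown",
   "id": "sketch-detokenize",
   "metadata": {},
   "source": [
    "The trailing space in this output comes from the post-processing at the end of `translate()`: every `</w>` end-of-word marker becomes a space, including the one after the final punctuation mark, and spaces before punctuation are then removed. The sketch below replays those steps on a hypothetical list of subword tokens (the real tokens depend on the tokenizer) with a shorter punctuation list than `translate()` uses:\n",
    "\n",
    "```python\n",
    "# hypothetical subword tokens; '</w>' marks the end of a word\n",
    "tokens = ['je</w>', 'ne</w>', 'par', 'le</w>', 'pas</w>',\n",
    "          'franc', 'ais</w>', '.</w>']\n",
    "\n",
    "trans = ''.join(tokens)             # 'je</w>ne</w>parle</w>...'\n",
    "trans = trans.replace('</w>', ' ')  # 'je ne parle pas francais . '\n",
    "for x in '?:;.,!':                  # shorter list than in translate()\n",
    "    trans = trans.replace(f' {x}', f'{x}')\n",
    "print(repr(trans))  # 'je ne parle pas francais. '\n",
    "```\n",
    "\n",
    "Calling `.strip()` on the result would remove the trailing space if it matters downstream."
   ]
  },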
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "6c4fea34",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "un petit garcon en jeans grimpe un petit arbre tandis qu'un autre enfant regarde. \n"
     ]
    }
   ],
   "source": [
    "eng = \"A little boy in jeans climbs a small tree while another child looks on.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "03d060f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "eng = \"I don't speak French.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f342772",
   "metadata": {},
   "source": [
    "Now let's try the same sentence without the contraction: \"I do not speak French.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f39f8649",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "je ne parle pas francais. \n"
     ]
    }
   ],
   "source": [
    "eng = \"I do not speak French.\"\n",
    "translated_fr = translate(eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "279e4604",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "j'aime le ski en hiver! \n",
      "comment etes-vous? \n"
     ]
    }
   ],
   "source": [
    "# exercise 9.3\n",
    "eng = \"I love skiing in the winter!\"\n",
    "translated_fr = translate(eng)\n",
    "eng = \"How are you?\"\n",
    "translated_fr = translate(eng)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
